package com.xueyuan.wata.daph.node.spark3.dataframe.batch.connector.hbase

import com.xueyuan.wata.daph.spark3.api.node.dataframe.connector.output.DataFrameSingleOutput
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hbase.client.{Connection, ConnectionFactory}
import org.apache.hadoop.hbase.spark.datasources.HBaseTableCatalog
import org.apache.hadoop.hbase.spark.{ByteArrayWrapper, FamiliesQualifiersValues, HBaseContext}
import org.apache.hadoop.hbase.tool.LoadIncrementalHFiles
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{DataFrame, Row}

import scala.util.control.Breaks._

/**
 * Batch output node that writes a DataFrame into an HBase table through an HFile bulk load.
 *
 * Every DataFrame column is cast to string, rows are turned into HFiles under a
 * timestamped sub-directory of the configured staging directory, and the HFiles are then
 * loaded into the table described by the configured catalog. With save mode `overwrite`
 * the table is truncated first; any other (or empty) save mode appends.
 */
class HBaseOutput extends DataFrameSingleOutput {

  /**
   * Writes `df` to the HBase table named by the node's catalog.
   *
   * The HBase connection is always closed and the staging directory is always cleaned up
   * (best effort), whether or not the load succeeds.
   *
   * @param df the DataFrame to persist; its column names are looked up in the catalog
   * @throws Exception if the catalog yields no column families
   */
  override def out(df: DataFrame): Unit = {
    val config = nodeConfig.asInstanceOf[HBaseOutputConfig]
    // A per-run sub-directory keeps HFiles of different runs apart.
    val stagingDir = s"${config.stagingDir}/${System.currentTimeMillis()}"
    val sm = config.saveMode
    val saveMode = if (sm.nonEmpty) sm else "append"

    val hbaseConf = HBaseConfiguration.create(spark.sessionState.newHadoopConf())
    config.hbaseOptions.foreach { case (key, value) => hbaseConf.set(key, value) }
    val hbaseContext = new HBaseContext(spark.sparkContext, hbaseConf)

    val htc = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog -> config.catalog))
    val tableName = TableName.valueOf(htc.namespace + ":" + htc.name)

    val colNames = df.columns
    // Validate before anything destructive (truncate) happens.
    if (htc.getColumnFamilies == null || colNames == null) {
      throw new Exception("null can't be convert to Bytes")
    }

    // Convert all columns to string so that every cell can be read with getAs[String].
    val stringDf = colNames.foldLeft(df) { (acc, colName) =>
      acc.withColumn(colName, col(colName).cast(DataTypes.StringType))
    }

    // Computed once on the driver instead of once per row on the executors.
    // Only plain arrays are captured by the row-mapping closure below, so neither `this`
    // nor the catalog object has to be serialized.
    val rowKeyColumns: Array[String] = htc.getRowKey.map(_.colName).toArray
    // (family bytes, qualifier bytes, DataFrame column name) for every non-row-key column.
    // NOTE(review): the qualifier is the DataFrame column name, not the qualifier declared
    // in the catalog field - confirm that this is intended.
    val columnSpecs: Array[(Array[Byte], Array[Byte], String)] =
      colNames
        .filter(name => htc.getField(name).cf != HBaseTableCatalog.rowKey)
        .map(name => (Bytes.toBytes(htc.getField(name).cf), Bytes.toBytes(name), name))

    val hbaseConn = ConnectionFactory.createConnection(hbaseConf)
    try {
      if (saveMode == "overwrite") {
        truncateHTable(hbaseConn, tableName)
      }

      hbaseContext.bulkLoadThinRows[Row](
        stringDf.rdd,
        tableName,
        row => {
          // The row key is the concatenation of the row-key columns in catalog order;
          // a null key column contributes the literal text "null".
          val rawKey = new StringBuilder
          rowKeyColumns.foreach(name => rawKey.append(row.getAs[String](name)))

          val cells = new FamiliesQualifiersValues
          columnSpecs.foreach { case (family, qualifier, name) =>
            val value = row.getAs[String](name)
            // Null cells are simply not written.
            if (value != null) {
              cells += (family, qualifier, Bytes.toBytes(value))
            }
          }
          // Bytes.toBytes encodes as UTF-8, consistent with how the cell values are encoded.
          (new ByteArrayWrapper(Bytes.toBytes(rawKey.toString)), cells)
        },
        stagingDir)

      withResource(hbaseConn.getAdmin) { admin =>
        withResource(hbaseConn.getTable(tableName)) { table =>
          withResource(hbaseConn.getRegionLocator(tableName)) { locator =>
            new LoadIncrementalHFiles(hbaseConf)
              .doBulkLoad(new Path(stagingDir), admin, table, locator)
          }
        }
      }
    } finally {
      // Clean up the staging directory even if closing the connection fails.
      try hbaseConn.close()
      finally cleanUpStagingDir(stagingDir, hbaseConf)
    }
  }

  /**
   * Runs `body` with `resource` and closes the resource afterwards, even on failure.
   *
   * @param resource the resource to use and close
   * @param body     the computation that needs the resource
   * @return whatever `body` returns
   */
  private def withResource[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource)
    finally resource.close()

  /**
   * Recursively deletes the staging directory, logging a warning instead of failing.
   *
   * Cleanup is best effort: it runs in a `finally` block, so an exception raised here would
   * otherwise replace the outcome of the load itself. The FileSystem is deliberately not
   * closed, because instances obtained from `Path.getFileSystem` are normally cached and
   * shared, and closing one would affect its other users.
   *
   * @param stagingDir the directory to remove
   * @param conf       the configuration used to resolve the directory's file system
   */
  private def cleanUpStagingDir(
      stagingDir: String,
      conf: org.apache.hadoop.conf.Configuration): Unit = {
    val stagingPath = new Path(stagingDir)
    try {
      val fs = stagingPath.getFileSystem(conf)
      // The directory may not exist if the job failed before any HFile was written.
      if (fs.exists(stagingPath) && !fs.delete(stagingPath, true)) {
        logger.warn(s"clean staging dir $stagingDir failed")
      }
    } catch {
      case scala.util.control.NonFatal(e) =>
        logger.warn(s"clean staging dir $stagingDir failed: ${e.getMessage}")
    }
  }

  /**
   * Truncates the table (keeping its region splits) if it exists; does nothing otherwise.
   *
   * @param connection an open HBase connection; it is not closed by this method
   * @param tableName  the table to truncate
   */
  private def truncateHTable(connection: Connection, tableName: TableName): Unit =
    withResource(connection.getAdmin) { admin =>
      if (admin.tableExists(tableName)) {
        // Disabling an already disabled table throws, so only disable when needed.
        if (admin.isTableEnabled(tableName)) {
          admin.disableTable(tableName)
        }
        admin.truncateTable(tableName, true)
      }
    }

  /** The configuration class this node is configured with. */
  override def getNodeConfigClass: Class[HBaseOutputConfig] = classOf[HBaseOutputConfig]
}
